package com.lyon.demo.netty.client;

import com.lyon.demo.common.spi.annotation.LyonSpi;
import com.lyon.demo.netty.core.InFlightRequests;
import com.lyon.demo.netty.core.NettyTransport;
import com.lyon.demo.netty.core.ResponseInvocation;
import com.lyon.demo.netty.core.bytebuf.CommandDecoder;
import com.lyon.demo.netty.core.bytebuf.CommandEncoder;
import com.lyon.demo.rpc.api.core.CommonProtocol;
import com.lyon.demo.rpc.api.endpoint.Transport;
import com.lyon.demo.rpc.api.endpoint.TransportClient;
import com.lyon.demo.rpc.api.naming.NamingClient;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.TimeoutException;

/**
 * Netty-based {@link TransportClient} registered under the {@link CommonProtocol#NETTY} SPI key.
 * <p>
 * On first use it lazily creates a single {@link NioEventLoopGroup} and {@link Bootstrap}, then
 * caches at most one connected {@link Channel} per remote {@link SocketAddress}. Every
 * {@link Transport} handed out shares the same {@link InFlightRequests} registry, which is also
 * given to the {@link ResponseInvocation} handler in each channel pipeline.
 * <p>
 * NOTE(review): nothing in this class shuts down the event loop group or closes cached channels —
 * confirm whether {@code TransportClient} is expected to provide a {@code close()}.
 *
 * @author Lyon
 */
@Slf4j
@LyonSpi(value = CommonProtocol.NETTY)
// NOTE(review): @Sharable only has meaning on ChannelHandler classes; this class is not one.
@ChannelHandler.Sharable
public class NettyClient implements TransportClient, NamingClient {

    /** Event loop shared by every channel this client opens; created lazily under lock. */
    private NioEventLoopGroup nioEventLoopGroup;
    /** Connection template bound to {@link #nioEventLoopGroup}; created lazily under lock. */
    private Bootstrap bootstrap;
    /** Currently open channels; thread-safe, and each channel removes itself when it closes. */
    private final List<Channel> channels = new CopyOnWriteArrayList<>();
    /** Reusable channel per remote address; entries are removed when their channel closes. */
    private final Map<SocketAddress, Channel> channelMap = new ConcurrentHashMap<>();
    /** Pending-request registry shared by all transports and the response handler. */
    private final InFlightRequests inFlightRequests = new InFlightRequests();

    /**
     * Returns a transport over a connected channel to {@code socketAddress}, reusing a cached
     * active channel when one exists.
     *
     * @param socketAddress  remote endpoint to connect to
     * @param connectTimeout maximum time to wait for the connect, in milliseconds
     * @return a transport bound to the channel and the shared in-flight request registry
     * @throws InterruptedException if interrupted while waiting for the connection
     */
    @Override
    public Transport createTransport(SocketAddress socketAddress, long connectTimeout) throws InterruptedException {
        final Channel channel = createChannel(socketAddress, connectTimeout);
        return new NettyTransport(channel, inFlightRequests);
    }

    /**
     * Returns the cached active channel for {@code socketAddress}, or connects a new one.
     * <p>
     * If the connect times out, fails or is interrupted, the pending attempt is cancelled and its
     * channel closed, so no half-open socket is left behind.
     *
     * @throws InterruptedException if interrupted while awaiting the connect
     * @throws TimeoutException     (thrown sneakily via Lombok) if the connect does not complete in time
     * @throws IllegalStateException if the connect completed but failed; the cause is attached
     */
    @SneakyThrows
    private Channel createChannel(SocketAddress socketAddress, long connectTimeout) throws InterruptedException {
        final Channel cached = channelMap.get(socketAddress);
        if (cached != null && cached.isActive()) {
            return cached;
        }
        final ChannelFuture channelFuture = ensureBootstrap().connect(socketAddress);
        final Channel channel = channelFuture.channel();
        final boolean completed;
        try {
            completed = channelFuture.await(connectTimeout);
        } catch (InterruptedException e) {
            abandon(channelFuture);
            throw e;
        }
        if (!completed) {
            abandon(channelFuture);
            throw new TimeoutException("connect to " + socketAddress + " timed out after " + connectTimeout + " ms");
        }
        if (!channelFuture.isSuccess() || channel == null || !channel.isActive()) {
            abandon(channelFuture);
            throw new IllegalStateException("failed to connect to " + socketAddress, channelFuture.cause());
        }
        register(socketAddress, channel);
        return channel;
    }

    /**
     * Creates the event loop group and bootstrap on first call and returns the bootstrap.
     * Synchronized so concurrent first callers cannot each create (and leak) a separate group.
     */
    private synchronized Bootstrap ensureBootstrap() {
        if (Objects.isNull(nioEventLoopGroup)) {
            nioEventLoopGroup = new NioEventLoopGroup();
        }
        if (Objects.isNull(bootstrap)) {
            newBootstrap();
        }
        return bootstrap;
    }

    /** Cancels an unfinished connect attempt and closes its channel so the socket is not leaked. */
    private static void abandon(ChannelFuture channelFuture) {
        channelFuture.cancel(false);
        final Channel channel = channelFuture.channel();
        if (channel != null) {
            channel.close();
        }
    }

    /**
     * Caches a freshly connected channel and arranges for it to be evicted when it closes.
     * {@code remove(key, value)} ensures only this channel's own mapping is removed, never a newer one.
     */
    private void register(SocketAddress socketAddress, Channel channel) {
        channelMap.put(socketAddress, channel);
        channels.add(channel);
        channel.closeFuture().addListener((ChannelFutureListener) future -> {
            channels.remove(channel);
            channelMap.remove(socketAddress, channel);
        });
    }

    /** Builds the shared bootstrap: NIO socket channels, pooled buffers, the RPC pipeline. */
    private void newBootstrap() {
        this.bootstrap = new Bootstrap().group(nioEventLoopGroup)
                .channel(NioSocketChannel.class)
                .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
                .handler(newChannelHandler());
    }

    /**
     * Returns the initializer that installs, in order, the command decoder, the response handler
     * (fed the shared in-flight registry) and the command encoder on every new channel.
     */
    private ChannelHandler newChannelHandler() {
        return new ChannelInitializer<>() {
            @Override
            protected void initChannel(Channel channel) throws Exception {
                channel.pipeline()
                        .addLast(new CommandDecoder())
                        .addLast(new ResponseInvocation(inFlightRequests))
                        .addLast(new CommandEncoder());
            }
        };
    }
}
